import itertools
import json
import sys
from collections import defaultdict
from typing import NamedTuple, Tuple, Dict, List, FrozenSet, Iterable, Callable, Set

import numpy as np
import tqdm.auto as tqdm

# Names of the state variables looked up in each shield state's JSON entry.
# Visible variables: their values (in this order) form a state's `label`.
VISIBLE_LABELS = ["rel_pos_x", "rel_pos_y"]
# Action variables: one per agent; their values (in this order) form a joint action tuple.
ACTION_LABELS = ["agent_0_action", "agent_1_action"]
# Hidden variables: parsed into `hidden_values`. The script entry point replaces this list when the input name
# contains "momentum".
HIDDEN_LABELS = ["agent_0_xmove", "agent_0_ymove", "agent_1_xmove", "agent_1_ymove"]


class HiddenValueShieldState(NamedTuple):
    """
    A single shield state as produced by `parse_shield`, i.e. before the actions are moved onto the outgoing
    transitions of the predecessor states.
    """
    label: Tuple[int, ...]  # Values of the visible state variables
    prev_actions: Tuple[int, ...]  # Values of the action variables; used as the joint action that enters this state
    next_states: FrozenSet[int]  # Entries of the state's "Successors" list
    hidden_values: Tuple[int, ...]  # Values of the hidden state variables
    violation: bool  # True when the state's "violation" variable equals 1


class OutgoingActionShieldState(NamedTuple):
    """
    Shield state whose joint actions are stored on its outgoing transitions.

    Equality and hashing deliberately ignore ``hidden_values``: two states with the same label, the same
    ``initial_state`` flag and the same action -> successors mapping are considered identical, which is what
    allows them to be merged when used as dictionary keys.
    """
    label: Tuple[int, ...]  # Values of the visible state variables
    initial_state: bool
    actions: Dict[Tuple[int, ...], FrozenSet[int]]  # Joint action -> set of successor state numbers
    hidden_values: Tuple[Set[int], ...]  # Not part of the state's identity (see __eq__ / __hash__)

    def __hash__(self):
        # Must only use the fields compared in __eq__ so that equal states hash equally.
        return hash((self.label, frozenset(self.actions.items()), self.initial_state))

    def __eq__(self, other):
        return (isinstance(other, OutgoingActionShieldState)
                and self.label == other.label
                and self.initial_state == other.initial_state
                and self.actions == other.actions)

    def __ne__(self, other):
        # tuple provides its own __ne__, which compares *every* field (including hidden_values). Without this
        # override `a == b` and `a != b` could both be True for states that only differ in their hidden values.
        return not self.__eq__(other)


# A centralized shield: state number -> state with its outgoing joint actions.
ShieldSpec = Dict[int, OutgoingActionShieldState]


class ActionPermutation(NamedTuple):
    """
    One set of per-agent allowed actions for a shield state, built by considering the agents in `agent_order`.

    As long as all agents select the same action permutation at the same time (e.g. using a PRNG with the same seed),
    the resulting joint action will be safe for everyone.
    """
    agent_order: Tuple[int, ...]  # Doesn't actually matter at runtime, this is just for info
    actions: List[Tuple[int, ...]]  # One entry per agent: the individual actions that agent may choose
    next_states: FrozenSet[int]  # State numbers recorded as possible successors under the allowed actions


class DecentralizedShieldState(NamedTuple):
    """
    A shield state whose allowed behaviour is expressed as a list of `ActionPermutation`s instead of a mapping
    from joint actions to successors.
    """
    label: Tuple[int, ...]  # Values of the visible state variables
    initial_state: bool
    action_permutations: List[ActionPermutation]  # One entry per agent ordering that was considered
    hidden_values: Tuple[Set[int], ...]  # Carried over unchanged from the centralized state (empty when loaded)


# A decentralized shield: state number -> state with its per-agent action permutations.
DecentralizedShieldSpec = Dict[int, DecentralizedShieldState]


def move_actions_to_outgoing_transitions(input_dict: Dict[int, HiddenValueShieldState],
                                         initial_cond: Callable[[HiddenValueShieldState], bool]) -> Dict[
    int, OutgoingActionShieldState]:
    """
    Convert states that remember the action which *entered* them into states that list their *outgoing* actions.

    For every state, its successors are grouped by the successor's `prev_actions`; each group becomes one entry of
    the new state's `actions` mapping. Every hidden value is wrapped into a singleton set.

    :param input_dict: Parsed shield, state number -> state. Every successor number must be a key of this dict.
    :param initial_cond: Predicate evaluated on each input state to fill in `initial_state`.
    :return: State number -> converted state (same state numbers as the input).
    """
    converted: Dict[int, OutgoingActionShieldState] = {}

    for number, source_state in tqdm.tqdm(input_dict.items(), desc="Processing next-state actions"):
        # Group the successors by the joint action stored in each successor.
        successors_by_action: Dict[Tuple[int, ...], Set[int]] = {}
        for successor in source_state.next_states:
            entering_action = input_dict[successor].prev_actions
            successors_by_action.setdefault(entering_action, set()).add(successor)

        frozen_groups = {action: frozenset(group) for action, group in successors_by_action.items()}
        singleton_hidden_values = tuple({value} for value in source_state.hidden_values)

        converted[number] = OutgoingActionShieldState(source_state.label, initial_cond(source_state),
                                                      frozen_groups, singleton_hidden_values)

    return converted


def combine_identical_states(input_dict: ShieldSpec) -> ShieldSpec:
    """
    Merge all states that compare equal (same label, same `initial_state` flag and same action -> successors
    mapping) into a single state, and renumber the result consecutively starting at 0.

    The hidden values of merged states are united position by position. The union is accumulated in fresh sets,
    so the states of `input_dict` (and their hidden-value sets) are left unmodified.

    :param input_dict: Shield to compress. Every successor number must be a key of this dict.
    :return: New shield in which successor numbers refer to the new numbering.
    """
    representatives: List[OutgoingActionShieldState] = []  # First state seen of every equivalence class
    merged_hidden_values: List[Tuple[Set[int], ...]] = []  # Parallel to `representatives`
    new_num_of_state: Dict[OutgoingActionShieldState, int] = {}
    old_to_new: Dict[int, int] = {}

    # Find all states which have the same label, initial flag and outgoing actions.
    for old_num, state in tqdm.tqdm(input_dict.items(), desc="Finding identical states"):
        new_num = new_num_of_state.get(state)
        if new_num is None:
            new_num = len(representatives)
            new_num_of_state[state] = new_num
            representatives.append(state)
            # Copy the sets: merging into the originals would silently modify the caller's input states.
            merged_hidden_values.append(tuple(set(values) for values in state.hidden_values))
        else:
            for collected, values in zip(merged_hidden_values[new_num], state.hidden_values):
                collected.update(values)

        old_to_new[old_num] = new_num

    def remap_single_shield_state(new_num: int, shield_state: OutgoingActionShieldState) -> OutgoingActionShieldState:
        # Successors still carry the old numbering; several old numbers may collapse into one new number.
        remapped_actions = {action: frozenset(old_to_new[next_state] for next_state in next_states)
                            for action, next_states in shield_state.actions.items()}
        return OutgoingActionShieldState(label=shield_state.label,
                                         initial_state=shield_state.initial_state,
                                         actions=remapped_actions,
                                         hidden_values=merged_hidden_values[new_num])

    return {new_num: remap_single_shield_state(new_num, shield_state)
            for new_num, shield_state in enumerate(representatives)}


def eliminate_bad_states(input_dict: ShieldSpec) -> ShieldSpec:
    """
    Remove deadlock states (states without any action) and every action that may lead into one of them.

    Only a single pass is performed: a state that loses all of its actions here is kept (with an empty action
    mapping) and is only removed by a subsequent call.

    :param input_dict: Shield to prune; it is not modified.
    :return: New shield containing the surviving states under their original state numbers.
    """
    # Using the bar as a context manager makes sure it is closed, even if an exception is raised.
    with tqdm.tqdm(total=len(input_dict) * 2, desc="Removing deadlock states") as progress:
        bad_states = set()
        for state_num, state in input_dict.items():
            progress.update(1)
            if not state.actions:
                bad_states.add(state_num)

        output_dict = {}
        for state_num, state in input_dict.items():
            progress.update(1)
            if state_num in bad_states:
                continue

            # An action is only kept if none of its possible successors is a deadlock state.
            outgoing_actions = {action: successors for action, successors in state.actions.items()
                                if successors.isdisjoint(bad_states)}
            output_dict[state_num] = OutgoingActionShieldState(state.label, state.initial_state, outgoing_actions,
                                                               state.hidden_values)

    return output_dict


def decentralize_shield(input_dict: ShieldSpec, action_space_dims) -> DecentralizedShieldSpec:
    """
    Turn a centralized shield (joint action -> successors) into per-agent sets of allowed individual actions.

    For every state and every ordering of the agents, a set of allowed individual actions per agent is grown
    greedily, starting from one joint action of the state. An individual action is only added if every joint action
    it forms together with the other agents' currently allowed actions is present in the state's `actions`, and if
    the successors reached this way do not put two different states under the same label.

    :param input_dict: Centralized shield. Every successor number must be a key of this dict.
    :param action_space_dims: Number of individual actions of each agent, indexed by agent number.
    :return: State number -> decentralized state holding one `ActionPermutation` per agent ordering.
    """
    dec_shield: DecentralizedShieldSpec = dict()
    # Every possible order in which the agents can be considered.
    agent_permutations = list(itertools.permutations(range(len(action_space_dims))))
    for state_num, state_info in tqdm.tqdm(input_dict.items(), desc="Decentralizing shield"):
        action_perms = []
        for agent_ordering in agent_permutations:
            # Prefer the all-zero joint action as the starting point; otherwise fall back to the first action
            # listed for this state.
            # NOTE(review): next(iter(...)) raises StopIteration for a state without any action.
            zero_action = tuple(0 for _ in action_space_dims)
            if zero_action in state_info.actions:
                safe_starting_action = zero_action
            else:
                safe_starting_action = next(iter(state_info.actions))

            # One boolean mask per agent; mask[a] is True when individual action `a` is allowed for that agent.
            allowed_indiv_actions = tuple(np.zeros(num_actions, dtype=bool) for num_actions in action_space_dims)
            for aia, ssa in zip(allowed_indiv_actions, safe_starting_action):
                aia[ssa] = True

            # Label -> the single successor state number accepted for that label so far.
            next_label_mapping: Dict[Tuple[int], int] = {}

            def add_state_set_to_next_label_mapping(state_set: Iterable[int]):
                """
                Try to register all states of `state_set` in `next_label_mapping`.

                Returns False (leaving the mapping untouched) if one of the states has a label that is already
                mapped to a different state; otherwise commits all of them and returns True.
                """
                # Work on a copy so that a failure part-way through leaves next_label_mapping unchanged.
                temp_mapping = dict(next_label_mapping)
                for state_num_to_add in state_set:
                    state_to_add = input_dict[state_num_to_add]
                    if state_to_add.label in temp_mapping and temp_mapping[state_to_add.label] != state_num_to_add:
                        # This label was already taken
                        return False
                    else:
                        temp_mapping[state_to_add.label] = state_num_to_add

                next_label_mapping.update(temp_mapping)
                return True

            # NOTE(review): the result is ignored here. If the successors of the starting action already conflict
            # on a label, none of them is recorded although the starting action stays allowed - confirm intended.
            add_state_set_to_next_label_mapping(state_info.actions[safe_starting_action])

            def try_adding_agent_individual_action(agent_num, action_to_try):
                """
                Allow `action_to_try` for agent `agent_num` if that keeps every implied joint action inside the
                state's `actions` and the resulting successors are accepted by the label mapping.
                """

                # What joint actions would adding this individual action imply?
                # And what states would this lead to?
                enabled_indiv_actions = tuple((action_to_try,) if other_agent_num == agent_num else tuple(
                    np.flatnonzero(allowed_indiv_actions[other_agent_num]))
                                              for other_agent_num in range(len(action_space_dims)))

                next_state_set = set()
                for newly_enabled_joint_action in itertools.product(*enabled_indiv_actions):
                    if newly_enabled_joint_action not in state_info.actions:
                        return  # This individual action would not be safe to add

                    next_state_set.update(state_info.actions[newly_enabled_joint_action])

                if add_state_set_to_next_label_mapping(next_state_set):
                    allowed_indiv_actions[agent_num][action_to_try] = True

            # Greedy growth: agents earlier in the ordering get to add their actions first, which is why the
            # result can differ between orderings.
            for agent_num in agent_ordering:
                for action_to_try in range(action_space_dims[agent_num]):
                    try_adding_agent_individual_action(agent_num, action_to_try)

            # Convert the boolean masks into tuples of plain ints (indexed by agent number, not by ordering).
            action_perms.append(ActionPermutation(agent_order=agent_ordering,
                                                  actions=[tuple(int(x) for x in np.flatnonzero(agent_indiv_actions))
                                                           for agent_indiv_actions in allowed_indiv_actions],
                                                  next_states=frozenset(next_label_mapping.values())))

        dec_shield[state_num] = DecentralizedShieldState(state_info.label, state_info.initial_state, action_perms,
                                                         state_info.hidden_values)

    return dec_shield


def parse_shield(shield_json, visible_label: List[str], action_label: List[str], hidden_label: List[str]):
    """
    Convert the raw shield JSON into a mapping state number -> `HiddenValueShieldState`.

    Each JSON entry is expected to provide a "State" mapping (containing all requested variables plus
    "violation") and a "Successors" list.

    :param shield_json: Mapping from state number (as string) to the state's JSON entry.
    :param visible_label: Variable names whose values form the state's label.
    :param action_label: Variable names whose values form the state's `prev_actions`.
    :param hidden_label: Variable names whose values form the state's `hidden_values`.
    :return: Dict keyed by the integer state number.
    """

    def read_ints(variables, names):
        # Values of the requested variables, in the order of `names`.
        return tuple(int(variables[name]) for name in names)

    parsed = {}
    for raw_num, raw_state in shield_json.items():
        key = int(raw_num)
        variables = raw_state["State"]

        visible_values = read_ints(variables, visible_label)
        action_values = read_ints(variables, action_label)
        hidden_values = read_ints(variables, hidden_label)
        is_violation = variables["violation"] == 1
        successors = frozenset(raw_state["Successors"])

        parsed[key] = HiddenValueShieldState(label=visible_values, prev_actions=action_values,
                                             next_states=successors, hidden_values=hidden_values,
                                             violation=is_violation)

    return parsed


def outgoing_action_shield_to_json(list_of_states: ShieldSpec, visible_label: List[str],
                                   action_label: List[str], hidden_label: List[str]):
    """
    Build the JSON-serializable representation of a centralized shield.

    :param list_of_states: Shield to serialize (state number -> state).
    :param visible_label: Names paired with the entries of each state's label.
    :param action_label: Names paired with the entries of each joint action.
    :param hidden_label: Names paired with the entries of each state's hidden values.
    :return: Dict keyed by the state number as a string; every value has the keys "state", "actions", "initial"
             and "hidden".
    """

    def encode_state(state):
        transitions = []
        for joint_action, successors in state.actions.items():
            transitions.append({
                "action": dict(zip(action_label, joint_action)),
                "successors": list(successors),
            })

        return {
            "state": dict(zip(visible_label, state.label)),
            "actions": transitions,
            "initial": state.initial_state,
            # Sets are not JSON-serializable, hence the conversion to lists.
            "hidden": {name: list(values) for name, values in zip(hidden_label, state.hidden_values)},
        }

    return {str(number): encode_state(state) for number, state in list_of_states.items()}


def dec_shield_to_json(shield: DecentralizedShieldSpec, visible_label: List[str], hidden_label: List[str]):
    """
    Build the JSON-serializable representation of a decentralized shield.

    :param shield: Shield to serialize (state number -> state).
    :param visible_label: Names paired with the entries of each state's label.
    :param hidden_label: Names paired with the entries of each state's hidden values.
    :return: Dict keyed by the state number as a string; every value has the keys "state", "actions" (one entry
             per action permutation), "initial" and "hidden".
    """
    result = {}
    for number, dec_state in shield.items():
        visible_vars = dict(zip(visible_label, dec_state.label))

        encoded_permutations = []
        for perm in dec_state.action_permutations:
            encoded_permutations.append({
                "agent_order": perm.agent_order,
                "allowed_actions": perm.actions,
                "next_states": list(perm.next_states),
            })

        # Sets are not JSON-serializable, hence the conversion to lists.
        hidden_vars = {name: list(values) for name, values in zip(hidden_label, dec_state.hidden_values)}

        result[str(number)] = {
            "state": visible_vars,
            "actions": encoded_permutations,
            "initial": dec_state.initial_state,
            "hidden": hidden_vars,
        }

    return result


def load_centralized_shield(path: str) -> ShieldSpec:
    """
    Load a centralized shield from the file `path + ".shield2"`.

    The visible variables and the action variables are read using the module-level VISIBLE_LABELS and
    ACTION_LABELS. The "hidden" entry of the file is not read; the loaded states have empty `hidden_values`.

    :param path: File path without the ".shield2" extension.
    :return: State number -> state.
    """
    with open(path + ".shield2", "r") as file:
        raw_shield = json.load(file)

    shield = {}
    for raw_num, raw_state in raw_shield.items():
        visible_values = tuple(int(raw_state["state"][name]) for name in VISIBLE_LABELS)

        # Joint action -> successor state numbers.
        transitions = {
            tuple(int(entry["action"][name]) for name in ACTION_LABELS):
                frozenset(int(successor) for successor in entry["successors"])
            for entry in raw_state["actions"]
        }

        shield[int(raw_num)] = OutgoingActionShieldState(visible_values, raw_state["initial"], transitions, ())

    return shield


def load_decentralized_shield(path: str) -> DecentralizedShieldSpec:
    """
    Load a decentralized shield from the file `path + ".dec_shield"`.

    The visible variables are read using the module-level VISIBLE_LABELS. The "hidden" entry of the file is not
    read; the loaded states have empty `hidden_values`.

    :param path: File path without the ".dec_shield" extension.
    :return: State number -> state.
    """
    with open(path + ".dec_shield", "r") as file:
        raw_shield = json.load(file)

    def decode_permutation(entry):
        # Rebuild the tuples / frozenset that JSON flattened into lists.
        order = tuple(int(agent) for agent in entry["agent_order"])
        allowed = [tuple(int(action) for action in per_agent) for per_agent in entry["allowed_actions"]]
        successors = frozenset(int(successor) for successor in entry["next_states"])
        return ActionPermutation(agent_order=order, actions=allowed, next_states=successors)

    shield = {}
    for raw_num, raw_state in raw_shield.items():
        visible_values = tuple(int(raw_state["state"][name]) for name in VISIBLE_LABELS)
        permutations = [decode_permutation(entry) for entry in raw_state["actions"]]

        shield[int(raw_num)] = DecentralizedShieldState(visible_values, raw_state["initial"], permutations, ())

    return shield


if __name__ == '__main__':
    # Reads NAME.shield, writes the compressed centralized shield to NAME.shield2 and the decentralized shield to
    # NAME.dec_shield, where NAME is the first command line argument.
    if len(sys.argv) < 2:
        # Exit with a readable message instead of an IndexError traceback.
        sys.exit(f"Usage: {sys.argv[0]} <shield path without the '.shield' extension>")

    name = sys.argv[1]

    with open(name + ".shield", "r") as file:
        shield_json = json.load(file)

    if "momentum" in name:
        # Shields whose name contains "momentum" use a different set of hidden variables.
        HIDDEN_LABELS = ["agent_0_ximp", "agent_0_yimp", "agent_1_ximp", "agent_1_yimp", "rel_xmove", "rel_ymove",
                         "prev_rel_xmove", "prev_rel_ymove"]


        def initial_cond(shield_info: HiddenValueShieldState):
            # A state is initial when the current and previous relative moves all have the value 2.
            (xi0, yi0, xi1, yi1, rxm, rym, prxm, prym) = shield_info.hidden_values
            return prxm == 2 and prym == 2 and rxm == 2 and rym == 2
    else:
        def initial_cond(shield_info: HiddenValueShieldState):
            # Without further information every state is treated as a possible initial state.
            return True

    states = parse_shield(shield_json, VISIBLE_LABELS, ACTION_LABELS, HIDDEN_LABELS)
    outgoing_action_states = move_actions_to_outgoing_transitions(states, initial_cond)

    # Merging identical states and pruning deadlocks can enable each other, so repeat both until the number of
    # states no longer changes.
    while True:
        combined_states = combine_identical_states(outgoing_action_states)
        bad_states_removed = eliminate_bad_states(combined_states)

        same_as_previous = len(bad_states_removed) == len(outgoing_action_states)
        outgoing_action_states = bad_states_removed
        if same_as_previous:
            break

    output_json = outgoing_action_shield_to_json(outgoing_action_states, VISIBLE_LABELS, ACTION_LABELS, HIDDEN_LABELS)

    with open(name + ".shield2", "w") as file:
        json.dump(output_json, file, indent=1)

    # NOTE(review): hard-coded for two agents with five individual actions each - confirm against the environment.
    decentralized_shield = decentralize_shield(outgoing_action_states, (5, 5))

    dec_shield_json = dec_shield_to_json(decentralized_shield, VISIBLE_LABELS, HIDDEN_LABELS)
    with open(name + ".dec_shield", "w") as file:
        json.dump(dec_shield_json, file, indent=1)
